from .bert_post import get_bert_post, compute_bert_acc
from .lstm_post import get_lstm_post
from .bertcls_post import get_bertcls_post
import numpy as np


# Lookup from a model output index to its label string, used by
# get_tdnn_post: the label at list position i is compared against the target
# parsed from the sample's file name when the model's argmax is i.
class_dict = dict(enumerate([
    '163', '7367', '332', '1970', '4640', '8629', '6848',     # 0-6
    '1088', '460', '6272', '7312', '2136', '1867', '669',     # 7-13
    '3526', '3664', '3242', '19', '32', '5789', '118',        # 14-20
    '226', '7859', '3947', '1898', '2416', '1737', '4680',    # 21-27
]))

def get_nlp_post(type, outputs, batch_size, params, content):
    """Dispatch to the post-processing function named by *type*.

    Args:
        type: Name of a callable defined or imported at module level in this
            file (e.g. ``"get_tdnn_post"``). The parameter name shadows the
            builtin ``type`` but is kept so keyword callers keep working.
        outputs: Forwarded unchanged to the selected function.
        batch_size: Forwarded unchanged to the selected function.
        params: Forwarded unchanged to the selected function.
        content: Forwarded unchanged to the selected function.

    Returns:
        Whatever the selected post-processing function returns.

    Raises:
        ValueError: If *type* does not name a callable in this module.
    """
    # Resolve by plain name lookup instead of eval(): eval would execute any
    # expression it is given (e.g. from a config file), not just a name.
    post_fn = globals().get(type)
    if not callable(post_fn):
        raise ValueError(f"unknown post-processing function: {type!r}")
    return post_fn(outputs, batch_size, params, content)


def get_tdnn_post(outputs, batch_size, params, content, label_map=None):
    """Score TDNN predictions against labels parsed from input file paths.

    Args:
        outputs: Sequence of model output arrays; only ``outputs[0]`` is used.
            It is reshaped to ``(-1, num_classes)``, where ``num_classes`` is
            the first comma-separated number of the first ``#``-separated
            group in ``params['outputs_size']``.
        batch_size: Number of leading samples to score.
        params: Dict holding an ``'outputs_size'`` string, e.g. ``"28"``.
        content: Per-sample records; ``content[idx][0]`` is a file path whose
            basename, up to its first ``-``, is the expected label string.
        label_map: Optional mapping from predicted class index to label
            string. Defaults to the module-level ``class_dict``.

    Returns:
        List of ``batch_size`` ints: 1 where the argmax label equals the
        target, else 0.
    """
    if label_map is None:
        label_map = class_dict

    # Only outputs[0] is scored, so parse and reshape just that tensor. No
    # copy is needed because nothing below mutates the array.
    num_classes = int(params['outputs_size'].split("#")[0].split(",")[0])
    preds = np.asarray(outputs[0]).reshape(-1, num_classes).argmax(axis=1)

    npreds = []
    for idx in range(batch_size):
        # [:-4] strips the last 4 characters, presumably a 4-char file
        # extension such as ".wav" -- NOTE(review): confirm input naming.
        target = content[idx][0][:-4].split("/")[-1].split("-")[0]
        npreds.append(1 if target == label_map[int(preds[idx])] else 0)
    return npreds

def get_textcnn_post(outputs, batch_size, params, content):
    """Score TextCNN predictions against integer labels in input file paths.

    Args:
        outputs: Sequence of model output arrays; only ``outputs[0]`` is used.
            It is reshaped to ``(-1, num_classes)``, where ``num_classes`` is
            the first comma-separated number of the first ``#``-separated
            group in ``params['outputs_size']``.
        batch_size: Number of leading samples to score.
        params: Dict holding an ``'outputs_size'`` string, e.g. ``"2"``.
        content: Per-sample records; ``content[idx][0]`` is a file path whose
            text after the last ``_`` is the integer target class.

    Returns:
        List of ``batch_size`` ints: 1 where the argmax index equals the
        target, else 0.

    Raises:
        ValueError: If a target parsed from a path is not an integer.
    """
    # Only outputs[0] is scored, so parse and reshape just that tensor. No
    # copy is needed because nothing below mutates the array.
    num_classes = int(params['outputs_size'].split("#")[0].split(",")[0])
    preds = np.asarray(outputs[0]).reshape(-1, num_classes).argmax(axis=1)

    npreds = []
    for idx in range(batch_size):
        # [:-4] strips the last 4 characters, presumably a 4-char file
        # extension -- NOTE(review): confirm input naming.
        target = content[idx][0][:-4].split("_")[-1]
        npreds.append(1 if int(target) == int(preds[idx]) else 0)
    return npreds

def get_bertsst2_post(outputs, batch_size, params, content):
    """Score BERT SST-2 predictions against the label stored in each record.

    Args:
        outputs: Sequence of model output arrays; only ``outputs[0]`` is used.
            It is reshaped to ``(-1, num_classes)``, where ``num_classes`` is
            the first comma-separated number of the first ``#``-separated
            group in ``params['outputs_size']``.
        batch_size: Number of leading samples to score.
        params: Dict holding an ``'outputs_size'`` string, e.g. ``"2"``.
        content: Per-sample records whose last element ``content[idx][-1]``
            is the target class (anything ``int()`` accepts).

    Returns:
        List of ``batch_size`` ints: 1 where the argmax index equals the
        target, else 0.
    """
    # Only outputs[0] is scored, so parse and reshape just that tensor. No
    # copy is needed because nothing below mutates the array.
    num_classes = int(params['outputs_size'].split("#")[0].split(",")[0])
    preds = np.asarray(outputs[0]).reshape(-1, num_classes).argmax(axis=1)

    npreds = []
    for idx in range(batch_size):
        target = content[idx][-1]
        npreds.append(1 if int(target) == int(preds[idx]) else 0)
    return npreds
